from abc import ABC, abstractmethod
from torch import nn
import torch

class BaseDetector(ABC, nn.Module):
    """Abstract interface shared by every detector module in this file.

    Concrete detectors subclass this and implement ``forward``.
    """

    # Class-level flags that subclasses override. Within this file,
    # ``recurrent`` is True only for the detector whose ``forward`` also takes
    # a memory tensor, and ``params_free`` is True only for the detector that
    # registers no layers.
    recurrent = False
    params_free = False

    def __init__(self) -> None:
        # Only nn.Module's own setup is needed; there is no state at this level.
        super().__init__()

    @abstractmethod
    def forward(self, obs):
        """
        Returns probabilities over rm states (normalized)
        """

class PerfectDetector(BaseDetector):
    """Detector with no learnable parameters.

    It simply hands back the ``rm_state`` attribute already present on the
    observation.
    """

    params_free = True

    def __init__(self, obs_space):
        super().__init__()
        # Taken verbatim from the observation-space description
        # (presumably the number of rm states -- confirm against the caller).
        rm_state_spec = obs_space['rm_state']
        self.out_dim = rm_state_spec

    def forward(self, obs):
        """Return ``obs.rm_state`` unchanged."""
        rm_state = obs.rm_state
        return rm_state


class SimpleDetectorModel(BaseDetector):
    """Feed-forward convolutional detector.

    A small CNN encodes the image observation, and a single linear layer maps
    the flattened features to ``out_dim`` outputs. No normalization (e.g.
    softmax) is applied inside this class.
    """

    recurrent = False

    def __init__(self, obs_space, out_dim):
        """Build the CNN and the output layer.

        ``obs_space['image']`` must unpack into three values; the third is
        used as the number of input channels of the first convolution.
        """
        super().__init__()

        height, width, channels = obs_space['image']
        self.out_dim = out_dim

        conv_stack = [
            nn.Conv2d(channels, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.net = nn.Sequential(*conv_stack)

        # Spatial size after the stack: the first 2x2 conv removes one
        # row/column, the 2x2 pool halves (floor), and the two remaining 2x2
        # convs remove one row/column each.
        feat_h = (height - 1) // 2 - 2
        feat_w = (width - 1) // 2 - 2
        flat_features = feat_h * feat_w * 64
        self.output = nn.Linear(flat_features, self.out_dim)

    def forward(self, obs):
        """Map ``obs.image`` to a ``(-1, out_dim)`` tensor of outputs."""
        # Move the last axis to position 1 so it lines up with Conv2d's
        # channel axis (same result as transpose(1, 3) then transpose(2, 3)).
        channels_first = obs.image.permute(0, 3, 1, 2)
        features = self.net(channels_first)
        scores = self.output(features)
        return scores.view((-1, self.out_dim))

class RecurrentDectectorModel(BaseDetector):
    """Recurrent detector: CNN encoder -> LSTM cell -> MLP decoder.

    The LSTM state travels between calls as one ``memory`` tensor of width
    ``memory_size``: the first ``semi_memory_size`` columns hold the hidden
    state and the remaining columns hold the cell state.
    """

    recurrent = True

    def __init__(self, obs_space, out_dim) -> None:
        """Build the encoder, the LSTM cell and the decoder.

        ``obs_space['image']`` must unpack into three values ``(n, m, k)``;
        ``k`` is used as the number of input channels of the first
        convolution.

        Raises:
            ValueError: if the image is too small for the encoder to produce
                at least one spatial feature.
        """
        super().__init__()

        n, m, k = obs_space['image']
        self.out_dim = out_dim
        self.image_embedding_size = 16

        self.embedding_size = self.semi_memory_size

        # Spatial size of the encoder output: the first 2x2 conv removes one
        # row/column, the 2x2 pool halves (floor), and the two remaining 2x2
        # convs remove one row/column each.
        feat_h = (n - 1) // 2 - 2
        feat_w = (m - 1) // 2 - 2
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"image of size {n}x{m} is too small for the encoder; "
                "each spatial dimension must be at least 7"
            )
        # Number of features per sample once the encoder output is flattened.
        # It equals image_embedding_size only when the feature map is 1x1
        # (e.g. 7x7 images); sizing the LSTM input from the real value keeps
        # the batch dimension intact for larger images.
        encoder_out_size = feat_h * feat_w * self.image_embedding_size

        # image -> embedding
        self.encoder = nn.Sequential(
            nn.Conv2d(k, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, self.image_embedding_size, (2, 2)),
            nn.ReLU()
        )

        # embedding -> embedding
        self.memory_rnn = nn.LSTMCell(encoder_out_size, self.semi_memory_size)

        # embedding -> rm states
        self.decoder = nn.Sequential(
            nn.Linear(self.embedding_size, self.embedding_size//2),
            nn.ReLU(),
            nn.Linear(self.embedding_size//2, self.out_dim)
        )

    @property
    def memory_size(self):
        """Width of the ``memory`` tensor (hidden state + cell state)."""
        return 2*self.semi_memory_size

    @property
    def semi_memory_size(self):
        """Width of each half of ``memory`` (the LSTM hidden size)."""
        return self.image_embedding_size

    def forward(self, obs, memory):
        """Run one recurrent step.

        Args:
            obs: object whose ``image`` attribute is a 4-D tensor with the
                channel axis last.
            memory: tensor with ``memory_size`` columns and one row per
                sample, laid out as ``[hidden | cell]``.

        Returns:
            ``(out, memory)`` where ``out`` has shape ``(-1, out_dim)`` (raw
            decoder outputs; no normalization is applied here) and ``memory``
            is the updated ``[hidden | cell]`` tensor.
        """
        # Move the last axis to position 1 so it lines up with Conv2d's
        # channel axis (same result as transpose(1, 3) then transpose(2, 3)).
        x = obs.image.permute(0, 3, 1, 2)
        # Flatten per sample so spatial positions never leak into the batch
        # dimension when the feature map is larger than 1x1.
        input_embed = self.encoder(x).flatten(start_dim=1)
        hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:])

        hidden = self.memory_rnn(input_embed, hidden)
        memory = torch.cat(hidden, dim=1)
        output_embed = hidden[0]
        out = self.decoder(output_embed)

        out = out.view((-1, self.out_dim))
        return out, memory
